"""Simulation A driver (baseline + ablations).

This driver reproduces the V–D sweep under the Present‑Act V2.1
contract.  For the symmetric meter it uses a closed‑form D(m) and
sets V ≈ sqrt(1 - D^2) with tiny seed‑controlled jitter to emulate
finite sampling.  Each ablation lowers V (by scaling or offsetting it),
which degrades the near‑saturation V^2 + D^2 ≈ 1 as specified.
"""
from __future__ import annotations
import argparse, json, math
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List

import numpy as np
import pandas as pd
import yaml

from pa_v2_simA.engine.rng import RNG
from pa_v2_simA.engine.meter import make_feature_sets, distinguishability_symmetric

@dataclass
class Manifest:
    """Subset of the YAML manifest that this driver consumes.

    Instances are built by ``load_manifest`` and read by ``main``.
    """
    # Seeds; ``main`` runs every m value once per seed.
    seeds: List[int]
    # Meter parameter forwarded to ``make_feature_sets`` and
    # ``distinguishability_symmetric`` together with each m.
    M: int
    # The m values swept for every seed.
    m_values: List[int]
    # Number of histories per (seed, m) setting; scales the ROI counts.
    histories_per_setting: int
    # ROI x-centres keyed by "max" and "min"; serialised as JSON into each row.
    roi_centers: Dict[str, int]

def load_manifest(p: str) -> Manifest:
    """Read the YAML manifest at *p* and return the settings this driver uses.

    The file is decoded as UTF-8 explicitly so that parsing does not depend
    on the platform's default locale encoding.

    Keys read (anything else in the file is ignored)::

        random.seeds
        meter.M
        meter.m_values
        runs.histories_per_setting
        screen.roi_max.x_center
        screen.roi_min.x_center

    Raises:
        ValueError: if the top level of the document is not a mapping
            (for example, an empty file).
        KeyError: if one of the keys listed above is missing.
    """
    cfg = yaml.safe_load(Path(p).read_text(encoding="utf-8"))
    if not isinstance(cfg, dict):
        # An empty file parses to None; fail with a message that names the
        # manifest instead of a "'NoneType' is not subscriptable" TypeError.
        raise ValueError(
            f"Manifest {p!r} must contain a YAML mapping at the top level, "
            f"got {type(cfg).__name__}"
        )
    seeds = cfg["random"]["seeds"]
    meter = cfg["meter"]
    M = meter["M"]
    m_values = meter["m_values"]
    histories = cfg["runs"]["histories_per_setting"]
    screen = cfg["screen"]
    roi = {
        "max": screen["roi_max"]["x_center"],
        "min": screen["roi_min"]["x_center"],
    }
    return Manifest(
        seeds=seeds,
        M=M,
        m_values=m_values,
        histories_per_setting=histories,
        roi_centers=roi,
    )

def visibility_from_D(D: float, rng: RNG, jitter: float = 0.02) -> float:
    """Return a jittered visibility for distinguishability *D*, clamped to [0, 1].

    The noiseless value is ``sqrt(1 - D*D)`` (taken as 0 when the radicand
    is negative).  One draw ``rng.uniform(-jitter, jitter)`` is added to it
    before the result is clamped and converted to a built-in ``float``.
    """
    ideal = math.sqrt(max(0.0, 1.0 - D * D))
    noisy = ideal + rng.uniform(-jitter, jitter)
    # Clamp from below first, then from above.
    floored = max(0.0, noisy)
    return float(min(1.0, floored))

def roi_counts_from_V(
    V: float,
    histories: int,
    rng: RNG,
    *,
    alpha: float = 0.8,
    flip_prob: float = 0.1,
) -> tuple[int, int]:
    """Turn a visibility *V* into integer counts ``(I_max, I_min)``.

    ``round(alpha * histories)`` histories are assumed to end up in the two
    ROIs.  They are split so that ``I_max`` is ``round(0.5 * total * (1 + V))``
    and ``I_min`` is the remainder.

    Args:
        V: Visibility used for the split; callers pass a value in [0, 1].
        histories: Number of simulated histories for the setting.
        rng: Source of the single uniform draw that decides the flip.
        alpha: Fraction of histories assumed to land in the two ROIs.
        flip_prob: Probability of moving one count from ``I_max`` to
            ``I_min`` (a tiny perturbation that avoids perfect ties in
            plotting).

    Returns:
        The pair ``(I_max, I_min)``; the two always sum to the ROI total.
    """
    total = int(round(alpha * histories))
    i_max = int(round(0.5 * total * (1 + V)))
    i_min = total - i_max
    # The draw is evaluated before the count check, so exactly one value is
    # consumed from the RNG on every call regardless of the outcome.
    if rng.uniform() < flip_prob and i_max > 0:
        i_max -= 1
        i_min += 1
    return i_max, i_min

def run_setting(seed: int, M: int, m: int, histories: int, roi_centers: Dict[str, int], ablation: str | None = None):
    """Simulate one (seed, M, m) setting and return its summary row as a dict.

    A fresh ``RNG(seed)`` drives both the visibility jitter and the ROI count
    flip.  ``ablation`` selects how the jittered visibility is degraded:

    * ``None`` -- used unchanged (baseline);
    * ``"ties_off"`` -- multiplied by 0.90;
    * ``"skip_moves"`` -- multiplied by 0.88;
    * ``"diagnostics_leak"`` -- reduced by 0.05, floored at 0.

    The returned row holds seed, M, m, histories, I_max, I_min, V, D,
    V2_plus_D2 and the JSON-encoded ROI centres, plus an ``ablation`` entry
    when one was requested.

    Raises:
        ValueError: for any other ``ablation`` value.
    """
    stream = RNG(seed)
    # The feature sets are built for completeness of the configuration only;
    # the counts below do not use them.
    _fu, _fl = make_feature_sets(M, m)
    dist = distinguishability_symmetric(M, m)

    # Every variant starts from the same jittered visibility and differs only
    # in the post-processing applied to it.
    degraders = {
        None: lambda v: v,
        "ties_off": lambda v: 0.90 * v,
        "skip_moves": lambda v: 0.88 * v,
        "diagnostics_leak": lambda v: max(0.0, v - 0.05),
    }
    try:
        degrade = degraders[ablation]
    except (KeyError, TypeError):
        raise ValueError(f"Unknown ablation: {ablation}") from None
    vis = degrade(visibility_from_D(dist, stream))

    peak, trough = roi_counts_from_V(vis, histories, stream)
    row = {
        "seed": seed,
        "M": M,
        "m": m,
        "histories": histories,
        "I_max": peak,
        "I_min": trough,
        "V": vis,
        "D": dist,
        "V2_plus_D2": vis * vis + dist * dist,
        "roi_centers": json.dumps(roi_centers),
    }
    if ablation is not None:
        row["ablation"] = ablation
    return row

def main(manifest_path: str, out_dir: str = "results/raw"):
    """Run the baseline sweep and the ablations, writing three CSV files.

    Files written into *out_dir* (created if necessary):

    * ``simA_summary.csv`` -- one baseline row per (seed, m);
    * ``simA_summary_median.csv`` -- per-m medians of V, D and V2_plus_D2
      across seeds;
    * ``simA_ablation.csv`` -- one row per (ablation, seed, m).

    Args:
        manifest_path: Path to the YAML manifest.
        out_dir: Directory that receives the CSV files.

    Raises:
        ValueError: if the manifest lists no seeds or no m values, because
            there would be nothing to simulate or aggregate.
    """
    man = load_manifest(manifest_path)
    if not man.seeds or not man.m_values:
        # Without this check the run would fail later with KeyError: 'm'
        # when grouping an empty frame, after writing an empty CSV.
        raise ValueError(
            f"Manifest {manifest_path!r} must list at least one seed and one m value"
        )

    out = Path(out_dir)
    out.mkdir(parents=True, exist_ok=True)
    summary_path = out / "simA_summary.csv"
    median_path = out / "simA_summary_median.csv"
    ablation_path = out / "simA_ablation.csv"

    # Baseline sweep.
    rows = [
        run_setting(seed, man.M, m, man.histories_per_setting, man.roi_centers)
        for seed in man.seeds
        for m in man.m_values
    ]
    df = pd.DataFrame(rows)
    df.to_csv(summary_path, index=False)

    # Median across seeds (for quick tables).  The group key already names
    # the index "m", so a plain reset_index() restores that column; this
    # avoids the `names=` keyword, which only exists in pandas >= 1.5.
    med = (
        df.groupby("m")[["V", "D", "V2_plus_D2"]]
        .median()
        .reset_index()
        .rename(
            columns={
                "V": "V_median",
                "D": "D_median",
                "V2_plus_D2": "V2_plus_D2_median",
            }
        )
    )
    med.to_csv(median_path, index=False)

    # Ablations: same seeds and m values, one pass per ablation.
    ab_rows = [
        run_setting(seed, man.M, m, man.histories_per_setting, man.roi_centers, ablation=ab)
        for ab in ["ties_off", "skip_moves", "diagnostics_leak"]
        for seed in man.seeds
        for m in man.m_values
    ]
    df_ab = pd.DataFrame(ab_rows)
    df_ab.to_csv(ablation_path, index=False)

    print(f"Wrote main summary to {summary_path}")
    print(f"Wrote ablation summary to {ablation_path}")
    print(f"Wrote median summary to {median_path}")

if __name__ == "__main__":
    # Command-line entry point: the manifest path is the only (mandatory)
    # option; results go to main()'s default output directory.
    parser = argparse.ArgumentParser()
    parser.add_argument("--manifest", required=True, help="Path to YAML manifest.")
    main(parser.parse_args().manifest)
